# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates enums.py and prints it to stdout.

The JSON input can be generated via:
    clang -Xclang -ast-dump=json -fsyntax-only -fparse-all-comments mujoco.h
"""

import json
from typing import Any, Mapping, Sequence

from absl import app
from absl import flags

from google3.third_party.mujoco.introspect import ast_nodes
from google3.third_party.mujoco.introspect.codegen import formatter

# Parsed command-line flags; populated by app.run() before main() is called.
FLAGS = flags.FLAGS
# main() opens this path, and the flag defaults to None, so it must be
# supplied on the command line.
flags.DEFINE_string(
    'json_path', None,
    'Path to the JSON file representing the Clang AST for mujoco.h')

# One node of the Clang AST as returned by json.load() on the dump described
# in the module docstring: a string-keyed object read via keys such as 'kind',
# 'name', 'type' and 'inner' (the list of child nodes).
ClangJsonNode = Mapping[str, Any]


def traverse(node, visitor):
  """Calls `visitor.visit` on `node` and on every descendant, in pre-order.

  A node's children come from its 'inner' list, which leaf nodes omit. They
  are visited in list order, and each child's whole subtree is finished
  before its next sibling is started. An explicit stack is used in place of
  recursion.

  Args:
    node: Root of the tree to walk (a Clang AST JSON node when called from
      main).
    visitor: Any object providing a `visit(node)` method.
  """
  stack = [node]
  while stack:
    current = stack.pop()
    visitor.visit(current)
    # Push children right-to-left so that the leftmost child is popped next.
    stack.extend(reversed(current.get('inner', [])))


class MjEnumVisitor:
  """A Clang AST JSON node visitor for MuJoCo API enum declarations.

  Meant to be passed to `traverse`. It records two things:
    * every EnumDecl whose tag name starts with 'mj', keyed by its type
      spelling `enum <tag>`;
    * every TypedefDecl whose underlying type is such an enum, keyed by the
      typedef name.
  """

  def __init__(self):
    # `enum <tag>` -> EnumDecl for each MuJoCo enum definition seen so far.
    self._enums = {}
    # typedef name -> EnumDecl carrying that typedef's name and a copy of the
    # underlying enum's values.
    self._typedefs = {}

  def _make_enum(self, node: ClangJsonNode) -> ast_nodes.EnumDecl:
    """Makes an EnumDecl from a Clang AST EnumDecl node.

    Enumerator values follow C rules. An enumerator with an explicit
    initializer takes the value Clang folded for it. One without an
    initializer takes the previous enumerator's value plus one, or 0 if it is
    the first.

    Args:
      node: A Clang AST JSON node whose 'kind' is 'EnumDecl'.

    Returns:
      An EnumDecl named `enum <tag>` that maps each enumerator name to its
      integer value, in declaration order.
    """
    name = f"enum {node['name']}"
    values = {}
    next_value = 0
    # A declaration-only EnumDecl would have no 'inner' list; treat it as
    # having no enumerators rather than raising KeyError.
    for child in node.get('inner', []):
      if child.get('kind') != 'EnumConstantDecl':
        continue
      # An explicit initializer is dumped as the first 'inner' node and
      # carries the folded constant in 'value'. An enumerator without an
      # initializer has either no 'inner' at all, or only non-value nodes
      # (e.g. a comment attached by -fparse-all-comments). Both cases take
      # the implicit next value.
      inner = child.get('inner', [])
      if inner and 'value' in inner[0]:
        value = int(inner[0]['value'])
      else:
        value = next_value
      values[child['name']] = value
      next_value = value + 1
    return ast_nodes.EnumDecl(name=name, declname=name, values=values)

  def visit(self, node: ClangJsonNode) -> None:
    """Records `node` if it defines a MuJoCo enum or a typedef of one.

    Args:
      node: Any Clang AST JSON node. Nodes of other kinds are ignored.

    Raises:
      KeyError: If a typedef names an `enum mj...` type whose EnumDecl has not
        been visited earlier.
    """
    kind = node.get('kind')
    if kind == 'EnumDecl' and node.get('name', '').startswith('mj'):
      enum_decl = self._make_enum(node)
      self._enums[enum_decl.name] = enum_decl
    elif (kind == 'TypedefDecl' and
          node['type']['qualType'].startswith('enum mj')):
      qual_type = node['type']['qualType']
      # Assumes the enum definition is visited before the typedef that names
      # it (pre-order traversal with the enum defined first).
      enum = self._enums.get(qual_type)
      if enum is None:
        raise KeyError(
            f"typedef {node['name']!r} refers to {qual_type!r}, which has "
            'not been seen as a MuJoCo enum definition')
      self._typedefs[node['name']] = ast_nodes.EnumDecl(
          name=node['name'], declname=enum.declname, values=dict(enum.values))

  @property
  def enums(self) -> Mapping[str, ast_nodes.EnumDecl]:
    """MuJoCo enum definitions seen so far, keyed by `enum <tag>`."""
    return self._enums

  @property
  def typedefs(self) -> Mapping[str, ast_nodes.EnumDecl]:
    """Typedefs of MuJoCo enums seen so far, keyed by typedef name."""
    return self._typedefs


def main(argv: Sequence[str]) -> None:
  """Prints the generated enums.py module to stdout.

  Reads the Clang AST JSON dump named by --json_path and collects the
  typedefs of MuJoCo enums with MjEnumVisitor. It then emits them as the
  `ENUMS` mapping of a generated Python module.

  Args:
    argv: Arguments remaining after flag parsing; only the program name is
      accepted.

  Raises:
    app.UsageError: If positional arguments are given or --json_path is not
      set.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # --json_path defaults to None. Report a usage error instead of letting
  # open(None) fail with an unhelpful TypeError.
  if not FLAGS.json_path:
    raise app.UsageError('--json_path is required.')

  with open(FLAGS.json_path, 'r', encoding='utf-8') as f:
    root = json.load(f)

  visitor = MjEnumVisitor()

  traverse(root, visitor)

  enums_str = formatter.format_as_python_code(visitor.typedefs)

  print(f'''
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides information about MuJoCo API enums.

DO NOT EDIT. THIS FILE IS AUTOMATICALLY GENERATED.
"""

from typing import Mapping

from google3.third_party.mujoco.introspect.ast_nodes import EnumDecl

ENUMS: Mapping[str, EnumDecl] = {enums_str}
'''.strip())  # `print` adds a trailing newline


if __name__ == '__main__':
  app.run(main)
